import datetime
import functools
import operator

from django.db.models import Q
from django.template import loader
from django_filters import rest_framework as filters
from rest_framework.filters import OrderingFilter, SearchFilter

from sysreptor.pentests.models import Language, PentestProject, ProjectType, ProjectTypeScope
from sysreptor.pentests.models.archive import ArchivedProject
from sysreptor.utils.utils import parse_date_string


class MultiValueFilter(filters.Filter):
    """Base class for filters that need every value of their query parameter.

    Subclasses call ``_get_params`` to read all repeated values straight from
    ``request.GET`` (e.g. both ``a`` and ``b`` for ``?tag=a&tag=b``).
    """

    def _get_params(self):
        """Yield every raw value of this filter's query parameter.

        The parameter name is found by locating this filter in the parent
        FilterSet's ``filters`` mapping. Nothing is yielded when the parent has
        no (truthy) ``request``.
        """
        request = getattr(self.parent, 'request', None)
        if not request:
            return
        own_names = (
            name for name, candidate in self.parent.filters.items()
            if candidate == self
        )
        for name in own_names:
            yield from request.GET.getlist(name, [])


class MultiValueCharFilter(MultiValueFilter, filters.CharFilter):
    """Char filter that ORs together every value given for its query parameter.

    ``?tag=a&tag=b`` keeps rows matching ``a`` OR ``b``; with ``exclude=True``
    those rows are removed instead.

    Lookup handling (see ``_lookup_kwargs``):
    ``'array_contains'`` uses ``<field>__contains`` with the value wrapped in a
    list, ``'exact'`` compares for equality, and any other lookup is applied as
    ``<field>__<lookup_expr>``.
    """

    def filter(self, qs, value):
        if value in ([], (), {}, None, ''):
            return qs
        # Mirror django-filter's base Filter.filter: honor ``distinct=True``,
        # which matters for lookups spanning multi-valued relations.
        if self.distinct:
            qs = qs.distinct()

        # Prefer every raw value from the request. Fall back to the cleaned form
        # value when none are found under this filter's name (no request, a
        # FilterSet ``prefix``, or data not taken from request.GET); previously
        # that case produced an empty Q() and silently skipped filtering.
        values = list(self._get_params()) or [value]

        q_objects = Q()
        for val in values:
            q_objects |= Q(**self._lookup_kwargs(val))

        return qs.exclude(q_objects) if self.exclude else qs.filter(q_objects)

    def _lookup_kwargs(self, val):
        """Return the ORM keyword arguments that match a single value."""
        if self.lookup_expr == 'array_contains':
            # There is no ``array_contains`` ORM lookup; use ``contains`` with a
            # one-element list instead.
            return {f"{self.field_name}__contains": [val]}
        if self.lookup_expr in (None, '', 'exact'):
            return {self.field_name: val}
        return {f"{self.field_name}__{self.lookup_expr}": val}


class MultiValueTimeRangeFilter(MultiValueFilter, filters.CharFilter):
    """Filter a date/datetime field by one or more ``after|before`` ranges.

    Each value has the form ``<after>|<before>``. Either bound may be empty or
    the literal ``null`` to leave that side open, and unparseable bounds are
    ignored the same way. Values without exactly one ``|`` are skipped.
    Multiple ranges are ORed together; ``exclude=True`` removes matching rows
    instead of keeping them.

    The ``before`` bound covers the whole day. Rows are kept when the field is
    strictly less than ``before + 1 day``. This assumes ``parse_date_string``
    returns a date (TODO confirm).
    """

    def filter(self, qs, value):
        if value in ([], (), {}, None, ''):
            return qs
        # Mirror django-filter's base Filter.filter: honor ``distinct=True``.
        if self.distinct:
            qs = qs.distinct()

        q_objects = Q()
        # Fall back to the cleaned value when no raw values are found under this
        # filter's name (no request, FilterSet prefix, ...), so the range syntax
        # is honored in every case instead of silently skipping filtering.
        for val in list(self._get_params()) or [value]:
            parts = val.split('|')
            if len(parts) != 2:
                continue  # malformed range: ignored
            after, before = parts

            q_obj = Q()

            after_date = self._parse_bound(after)
            if after_date is not None:
                q_obj &= Q(**{f"{self.field_name}__gte": after_date})

            before_date = self._parse_bound(before)
            if before_date is not None:
                try:
                    day_after = before_date + datetime.timedelta(days=1)
                except OverflowError:
                    pass  # ``before`` is the last representable day: no upper bound needed
                else:
                    # Half-open bound (``__lt``): with ``__lte`` rows stamped exactly at
                    # the start of the following day were wrongly included.
                    q_obj &= Q(**{f"{self.field_name}__lt": day_after})

            # A range without any usable bound is an empty Q(), which adds no
            # condition when ORed with the others.
            q_objects |= q_obj

        return qs.exclude(q_objects) if self.exclude else qs.filter(q_objects)

    @staticmethod
    def _parse_bound(text):
        """Parse one range bound; return None for empty, ``'null'`` or unparseable text."""
        if not text or text == 'null':
            return None
        try:
            return parse_date_string(text)
        except ValueError:
            # Best-effort: an invalid bound leaves that side of the range open.
            return None


class PentestProjectFilterSet(filters.FilterSet):
    """FilterSet for listing PentestProject objects.

    Each ``X``/``not_X`` pair targets the same field. ``X`` keeps rows matching
    any of the values given for the parameter (repeat the parameter to pass
    several values). ``not_X`` removes those rows (``exclude=True``).
    """
    # Tags: array field checked for containing each given tag.
    tag = MultiValueCharFilter(field_name='tags', lookup_expr='array_contains')
    not_tag = MultiValueCharFilter(field_name='tags', lookup_expr='array_contains', exclude=True)
    # ``after|before`` date ranges on ``created`` (parsed by MultiValueTimeRangeFilter).
    timerange = MultiValueTimeRangeFilter(field_name='created', lookup_expr='range')
    not_timerange = MultiValueTimeRangeFilter(field_name='created', lookup_expr='range', exclude=True)
    language = MultiValueCharFilter(field_name='language', lookup_expr='exact')
    not_language = MultiValueCharFilter(field_name='language', lookup_expr='exact', exclude=True)
    # Username of a project member, matched through members__user__username.
    member = MultiValueCharFilter(field_name='members__user__username', lookup_expr='exact')
    not_member = MultiValueCharFilter(field_name='members__user__username', lookup_expr='exact', exclude=True)

    class Meta:
        model = PentestProject
        # Auto-generated filter for the model's ``readonly`` field.
        fields = ["readonly"]


class ArchivedProjectFilterSet(filters.FilterSet):
    """FilterSet for listing ArchivedProject objects.

    ``X`` keeps rows matching any of the values given for the parameter.
    ``not_X`` removes them (``exclude=True``).
    """
    # Tags: array field checked for containing each given tag.
    tag = MultiValueCharFilter(field_name='tags', lookup_expr='array_contains')
    not_tag = MultiValueCharFilter(field_name='tags', lookup_expr='array_contains', exclude=True)
    # ``after|before`` date ranges on ``created`` (parsed by MultiValueTimeRangeFilter).
    timerange = MultiValueTimeRangeFilter(field_name='created', lookup_expr='range')
    not_timerange = MultiValueTimeRangeFilter(field_name='created', lookup_expr='range', exclude=True)

    class Meta:
        model = ArchivedProject
        # No auto-generated model-field filters; only the declared ones above.
        fields = []


class FindingTemplateFilter(filters.FilterSet):
    """FilterSet for listing finding templates.

    ``X`` keeps rows matching any of the values given for the parameter.
    ``not_X`` removes them (``exclude=True``). Language, status and risk level
    are matched through the related ``translations``. No ``Meta`` is declared,
    so only the filters below exist.

    NOTE(review): OR-ing several values across the ``translations`` join may
    return a template once per matching translation unless the queryset is
    made distinct elsewhere. Confirm.
    """
    language = MultiValueCharFilter(field_name='translations__language', lookup_expr='exact')
    not_language = MultiValueCharFilter(field_name='translations__language', lookup_expr='exact', exclude=True)
    # ``after|before`` date ranges on ``created`` (parsed by MultiValueTimeRangeFilter).
    timerange = MultiValueTimeRangeFilter(field_name='created', lookup_expr='range')
    not_timerange = MultiValueTimeRangeFilter(field_name='created', lookup_expr='range', exclude=True)
    # Tags: array field checked for containing each given tag.
    tag = MultiValueCharFilter(field_name='tags', lookup_expr='array_contains')
    not_tag = MultiValueCharFilter(field_name='tags', lookup_expr='array_contains', exclude=True)
    status = MultiValueCharFilter(field_name='translations__status', lookup_expr='exact')
    not_status = MultiValueCharFilter(field_name='translations__status', lookup_expr='exact', exclude=True)
    risk_level = MultiValueCharFilter(field_name='translations__risk_level', lookup_expr='exact')
    not_risk_level = MultiValueCharFilter(field_name='translations__risk_level', lookup_expr='exact', exclude=True)

    # Single Language choice handled by ``filter_preferred_language``.
    preferred_language = filters.ChoiceFilter(choices=Language.choices, method='filter_preferred_language', label='Preferred Language')

    def filter_preferred_language(self, queryset, name, value):
        """Delegate to ``queryset.order_by_language(value)``.

        Presumably this reorders results so templates in the preferred language
        come first rather than narrowing them. Confirm against the queryset
        implementation.
        """
        return queryset.order_by_language(value)


class ProjectTypeFilter(filters.FilterSet):
    language = MultiValueCharFilter(field_name='language', lookup_expr='exact')
    not_language = MultiValueCharFilter(field_name='language', lookup_expr='exact', exclude=True)
    timerange = MultiValueTimeRangeFilter(field_name='created', lookup_expr='range')
    not_timerange = MultiValueTimeRangeFilter(field_name='created', lookup_expr='range', exclude=True)
    tag = MultiValueCharFilter(field_name='tags', lookup_expr='array_contains')
    not_tag = MultiValueCharFilter(field_name='tags', lookup_expr='array_contains', exclude=True)
    status = MultiValueCharFilter(field_name='status', lookup_expr='exact')
    not_status = MultiValueCharFilter(field_name='status', lookup_expr='exact', exclude=True)

    scope = filters.MultipleChoiceFilter(label='Scopes', choices=ProjectTypeScope.choices, method='filter_scopes')
    linked_project = filters.UUIDFilter(label='Linked project', method='filter_linked_project')

    class Meta:
        model = ProjectType
        fields = ['language']

    def filter_scopes(self, queryset, name, value):
        scope_filters = []
        for v in set(value):
            if v == 'global':
                scope_filters.append(Q(linked_project=None) & Q(linked_user=None))
            elif v == 'private':
                scope_filters.append(Q(linked_user=self.request.user))
            elif v == 'project':
                scope_filters.append(Q(linked_project__isnull=False))

        return queryset.filter(functools.reduce(operator.or_, scope_filters))

    def filter_linked_project(self, queryset, name, value):
        return queryset.filter(Q(linked_project=None) | Q(linked_project_id=value))


class ProjectTypeOrderingFilter(OrderingFilter):
    """Ordering filter for project types.

    Public sort keys ``scope``, ``status`` and ``usage`` (optionally prefixed
    with ``-``) are translated to the columns ``scope_order``, ``status_order``
    and ``usage_count``. The queryset is annotated when a scope or status order
    is needed. The default ordering is newest first.
    """
    ordering_fields = ['created', 'updated', 'name', 'scope', 'status', 'usage']

    def get_queryset_ordering(self, request, queryset, view):
        """Return the requested ordering with public sort keys mapped to queryset columns."""
        column_for = {
            'scope': 'scope_order',
            '-scope': '-scope_order',
            'status': 'status_order',
            '-status': '-status_order',
            'usage': 'usage_count',
            '-usage': '-usage_count',
        }
        return [column_for.get(term, term) for term in self.get_ordering(request, queryset, view)]

    def filter_queryset(self, request, queryset, view):
        ordering = self.get_queryset_ordering(request, queryset, view)
        if not ordering:
            return queryset

        # Only add the (presumably costlier) annotations when the ordering uses them.
        requested = ''.join(ordering)
        if 'scope_order' in requested:
            queryset = queryset.annotate_scope_order()
        if 'status_order' in requested:
            queryset = queryset.annotate_status_order()
        return queryset.order_by(*ordering)

    def get_default_ordering(self, view):
        return ['-created']


class FindingTemplateSearchFilter(SearchFilter):
    """Search filter that hands the parsed search terms to ``queryset.search()``."""

    def filter_queryset(self, request, queryset, view):
        """Apply ``queryset.search(terms)``; leave the queryset untouched when no terms are given."""
        terms = self.get_search_terms(request)
        return queryset.search(terms) if terms else queryset

    def to_html(self, request, queryset, view):
        """Render the search box with the inherited ``template`` and the current search term."""
        current_term = request.query_params.get(self.search_param, '')
        return loader.get_template(self.template).render({
            'param': self.search_param,
            'term': current_term,
        })


class FindingTemplateOrderingFilter(OrderingFilter):
    """Ordering filter for finding templates.

    Only the first requested sort key is used. ``risk`` and ``usage`` expand to
    several tie-breaking columns. A leading ``-`` inverts every column. When
    the queryset is already ordered by ``-has_language`` and/or
    ``-search_rank``, those keys stay in front. The default ordering is ``-risk``.
    """
    ordering_fields = ['created', 'updated', 'risk', 'usage']

    def get_queryset_ordering(self, request, queryset, view):
        """Return the column list for ``order_by``, or None for an unsupported key."""
        requested = self.get_ordering(request, queryset, view)[0]
        key = requested.removeprefix('-')

        if key == 'risk':
            columns = ['risk_level_number', 'risk_score_number', 'created']
        elif key == 'usage':
            columns = ['usage_count', 'risk_level_number', 'risk_score_number', 'created']
        elif key in ['created', 'updated']:
            columns = [key]
        else:
            return None

        if requested.startswith('-'):
            columns = ['-' + column for column in columns]

        # Keep a previously applied language/search-rank ordering as the primary key.
        current = list(queryset.query.order_by)
        if current in [['-has_language', '-search_rank'], ['-has_language'], ['-search_rank']]:
            columns = current + columns

        return columns

    def filter_queryset(self, request, queryset, view):
        ordering = self.get_queryset_ordering(request, queryset, view)
        if not ordering:
            return queryset
        return queryset.annotate_risk_level_number().order_by(*ordering)

    def get_default_ordering(self, view):
        return ['-risk']
